/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

/*
 * Pass 1 of batched field inversion (simultaneous-inversion trick):
 * each lane (gid) walks its strided slice of `us` and writes the
 * running prefix products into `us_inv`, so that for this lane
 * us_inv[k] = us[0] * us[1] * ... * us[k].
 * Element k of a lane lives at index k * gsize + gid (lane-strided layout).
 */
KERNEL_SPEC(batch_invert_kernel_0, BLOCK_SIZE)
(
    GLOBAL uint32_t* us_inv,
    GLOBAL const uint32_t* __restrict__ us,
    GLOBAL const size_t* __restrict__ inv_batch_size
) {
    const size_t gid = global_id();
    const size_t gsize = global_size();
    const size_t batch = *inv_batch_size;
    const size_t fe_n = batch * gsize;

    /* Seed the running product with the lane's first element. */
    FieldElement acc, cur;
    fe_load(acc, us, fe_n, gid);
    fe_store(us_inv, fe_n, acc, gid);

    /* Fold each further element into the running product and store it. */
    for(size_t k = 1; k < batch; ++k) {
        const size_t off = k * gsize + gid;
        fe_load(cur, us, fe_n, off);
        fe_mul(acc, cur, acc);
        fe_store(us_inv, fe_n, acc, off);
    }
}

/*
 * Pass 2 of batched field inversion: invert, in place, the lane's total
 * product — i.e. the last prefix product written by pass 1. Only a single
 * fe_invert per lane is needed for the whole batch.
 */
KERNEL_SPEC(batch_invert_kernel_1, BLOCK_SIZE)
(
    GLOBAL uint32_t* us_inv,
    GLOBAL const size_t* __restrict__ inv_batch_size
) {
    const size_t gid = global_id();
    const size_t gsize = global_size();
    const size_t batch = *inv_batch_size;
    const size_t fe_n = batch * gsize;
    /* Index of this lane's last (= total) prefix product. */
    const size_t last = (batch - 1) * gsize + gid;

    FieldElement total;
    fe_load(total, us_inv, fe_n, last);
    fe_invert(total, total);
    fe_store(us_inv, fe_n, total, last);
}

/*
 * Pass 3 of batched field inversion: backward substitution. On entry,
 * us_inv holds the lane's prefix products except the last slot, which
 * pass 2 replaced with the inverse of the total product. Walking the
 * batch backwards recovers each element's individual inverse:
 *   us_inv[i] = prefix[i-1] * c   (inverse of us[i])
 *   c        *= us[i]             (c becomes inverse of prefix[i-1])
 * Finally the surviving c is the inverse of the lane's first element.
 *
 * Fix vs. previous version: the loop counter was `int`, initialized from
 * the size_t expression `*inv_batch_size - 1` (implicit narrowing,
 * signed/unsigned mixing, truncation for large batches), and indices
 * were cast through uint32_t. It now uses size_t throughout, consistent
 * with batch_invert_kernel_0.
 *
 * NOTE(review): like the other passes, this assumes *inv_batch_size >= 1;
 * a zero batch size underflows `batch - 1` — confirm the host never
 * launches with an empty batch.
 */
KERNEL_SPEC(batch_invert_kernel_2, BLOCK_SIZE)
(
    GLOBAL uint32_t* us_inv,
    GLOBAL const uint32_t* __restrict__ us,
    GLOBAL const size_t* __restrict__ inv_batch_size
) {
    const size_t gid = global_id();
    const size_t gsize = global_size();
    const size_t batch = *inv_batch_size;
    const size_t fe_n = batch * gsize;

    FieldElement c, t, v;
    /* c starts as the inverse of the lane's total product (from pass 2). */
    fe_load(c, us_inv, fe_n, gsize * (batch - 1) + gid);

    for(size_t i = batch - 1; i > 0; --i) {
        /* t = prefix[i-1] * c  ==  inverse of us[i] */
        fe_load(t, us_inv, fe_n, gsize * (i - 1) + gid);
        fe_mul(t, c, t);
        /* Fold us[i] back into c before overwriting slot i. */
        fe_load(v, us, fe_n, gsize * i + gid);
        fe_store(us_inv, fe_n, t, gsize * i + gid);
        fe_mul(c, v, c);
    }

    /* What remains of c is the inverse of the lane's first element. */
    fe_store(us_inv, fe_n, c, gid);
}
